package cn.bugstack.openai.executor.model.chatglm;

import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.chatglm.config.ChatGLMConfig;
import cn.bugstack.openai.executor.model.chatglm.valobj.CharGLMCompletionRequest;
import cn.bugstack.openai.executor.model.chatglm.valobj.ChatGLMCompletionRequest;
import cn.bugstack.openai.executor.model.chatglm.valobj.ChatGLMCompletionResponse;
import cn.bugstack.openai.executor.model.chatglm.valobj.Model;
import cn.bugstack.openai.executor.parameter.*;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.chatplus.application.common.util.PlusJsonUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;

import static cn.bugstack.openai.executor.parameter.CompletionRequest.Role.*;

/**
 * Executor for the CharGLM ("super-anthropomorphic" role-play) model.
 * <p>
 * Converts a generic {@link CompletionRequest} into a {@link CharGLMCompletionRequest}, POSTs it to the
 * ChatGLM completions endpoint as a server-sent-event stream, and re-wraps every streamed
 * {@link ChatGLMCompletionResponse} as a generic {@link CompletionResponse} for the caller's listener.
 * Image generation and picture understanding are not supported by this executor.
 *
 * @author: ZhangZhe
 */
@Slf4j
public class CharGLMModelExecutor implements Executor, ParameterHandler<CharGLMCompletionRequest>, ResultHandler {

    /**
     * ChatGLM configuration taken from the session {@link Configuration}.
     */
    private final ChatGLMConfig chatGLMConfig;
    /**
     * Factory used to open event-source (SSE) connections.
     */
    private final EventSource.Factory factory;

    public CharGLMModelExecutor(Configuration configuration) {
        this.chatGLMConfig = configuration.getChatGLMConfig();
        this.factory = configuration.createRequestFactory();
    }

    /**
     * Chat completion with streamed feedback, using the default API host.
     *
     * @param completionRequest   request information
     * @param eventSourceListener listener that receives the converted data through its {@code onEvent} method
     * @return the opened event source
     * @throws Exception on failure
     */
    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        Request request = buildRequest(ChatGLMConfig.API_HOST, completionRequest);
        return factory.newEventSource(request, eventSourceListener(eventSourceListener));
    }

    /**
     * Chat completion with streamed feedback, using a caller-supplied API host. This is intended for setups
     * where each user has an independent configuration.
     *
     * @param apiHostByUser       API host to use; when blank, the default {@code ChatGLMConfig.API_HOST} is used
     * @param apiKeyByUser        API key; currently not applied by this executor
     * @param completionRequest   request information
     * @param eventSourceListener listener that receives the converted data through its {@code onEvent} method
     * @return the opened event source
     * @throws Exception on failure
     */
    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // Honour the caller's host (previously it was silently ignored); fall back to the default when absent.
        String apiHost = StringUtils.defaultIfBlank(apiHostByUser, ChatGLMConfig.API_HOST);
        // NOTE(review): apiKeyByUser is not used here; authentication presumably comes from the HTTP client
        // behind `factory` — confirm before relying on per-user keys.
        Request request = buildRequest(apiHost, completionRequest);
        return factory.newEventSource(request, eventSourceListener(eventSourceListener));
    }

    /**
     * Builds the POST request for the completions endpoint.
     *
     * @param apiHost           base host; {@code ChatGLMConfig.V3_COMPLETIONS} is appended to it
     * @param completionRequest generic request; its model value replaces the {@code {model}} placeholder
     * @return the ready-to-send request carrying the converted CharGLM payload as JSON
     */
    private Request buildRequest(String apiHost, CompletionRequest completionRequest) {
        CharGLMCompletionRequest charGLMCompletionRequest = getParameterObject(completionRequest);
        String url = apiHost.concat(ChatGLMConfig.V3_COMPLETIONS).replace("{model}", completionRequest.getModel());
        RequestBody body = RequestBody.create(PlusJsonUtils.toJsonString(charGLMCompletionRequest), MediaType.parse(Configuration.APPLICATION_JSON));
        return new Request.Builder()
                .tag(completionRequest.getTag())
                .url(url)
                .post(body)
                .build();
    }

    /**
     * Image generation — not supported by this executor.
     *
     * @param imageRequest image description
     * @return always {@code null}
     */
    @Override
    public ImageResponse genImages(ImageRequest imageRequest) throws Exception {
        return null;
    }

    /**
     * Image generation with a caller-supplied host/key — not supported by this executor.
     *
     * @param apiHostByUser apiHost (unused)
     * @param apiKeyByUser  apiKey (unused)
     * @param imageRequest  image description
     * @return always {@code null}
     */
    @Override
    public ImageResponse genImages(String apiHostByUser, String apiKeyByUser, ImageRequest imageRequest) throws Exception {
        return null;
    }

    /**
     * Picture understanding — not supported by this executor.
     *
     * @param pictureRequest      picture and its description
     * @param eventSourceListener listener (unused)
     * @return always {@code null}
     * @throws Exception never thrown by this implementation
     */
    @Override
    public EventSource pictureUnderstanding(PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Picture understanding with a caller-supplied host/key — not supported by this executor.
     *
     * @param apiHostByUser       apiHost (unused)
     * @param apiKeyByUser        apiKey (unused)
     * @param pictureRequest      picture and its description
     * @param eventSourceListener listener (unused)
     * @return always {@code null}
     * @throws Exception never thrown by this implementation
     */
    @Override
    public EventSource pictureUnderstanding(String apiHostByUser, String apiKeyByUser, PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    /**
     * Converts a generic completion request into a CharGLM request.
     * <p>
     * Messages whose role is one of the persona roles ({@code USER_INFO}, {@code BOT_INFO}, {@code USER_NAME},
     * {@code BOT_NAME}) are moved into the request's {@code meta} block; every other message is forwarded
     * as a prompt, in order. A {@code null} message list is treated as empty.
     *
     * @param completionRequest the generic request
     * @return the CharGLM-specific request
     */
    @Override
    public CharGLMCompletionRequest getParameterObject(CompletionRequest completionRequest) {
        CharGLMCompletionRequest request = new CharGLMCompletionRequest();
        request.setModel(Model.getModel(completionRequest.getModel()));
        request.setTemperature(completionRequest.getTemperature());
        request.setTopP(completionRequest.getTopP());
        request.setMeta(CharGLMCompletionRequest.Meta.builder().build());
        List<ChatGLMCompletionRequest.Prompt> prompts = new ArrayList<>();
        List<Message> messages = completionRequest.getMessages();
        // Guard against a missing message list instead of failing with a NullPointerException.
        if (messages != null) {
            for (Message message : messages) {
                String role = message.getRole();
                if (StringUtils.equals(role, USER_INFO.getCode())) {
                    request.getMeta().setUserInfo(message.getContent());
                } else if (StringUtils.equals(role, BOT_INFO.getCode())) {
                    request.getMeta().setBotInfo(message.getContent());
                } else if (StringUtils.equals(role, USER_NAME.getCode())) {
                    request.getMeta().setUserName(message.getContent());
                } else if (StringUtils.equals(role, BOT_NAME.getCode())) {
                    request.getMeta().setBotName(message.getContent());
                } else {
                    prompts.add(ChatGLMCompletionRequest.Prompt.builder()
                            .role(role)
                            .content(message.getContent())
                            .build());
                }
            }
        }
        request.setMessages(prompts);
        return request;
    }

    @Override
    public boolean supportFunction(String model) {
        return false;
    }

    /**
     * Wraps the caller's listener so that raw ChatGLM stream events are converted into generic
     * {@link CompletionResponse} JSON before being forwarded. A {@code [DONE]} marker, or data that parses
     * to {@code null}, is dropped. Open/close/failure callbacks are passed through unchanged.
     *
     * @param eventSourceListener the caller's listener
     * @return the adapting listener
     */
    @Override
    public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
        return new EventSourceListener() {
            @Override
            public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
                if ("[DONE]".equals(data)) {
                    return;
                }
                ChatGLMCompletionResponse response = PlusJsonUtils.parseObject(data, ChatGLMCompletionResponse.class);
                if (null == response) {
                    return;
                }
                // Re-wrap the model-specific chunk in the generic response type.
                CompletionResponse completionResponse = new CompletionResponse();
                completionResponse.setChoices(response.getChoices());
                completionResponse.setUsage(response.getUsage());
                completionResponse.setCreated(System.currentTimeMillis());
                eventSourceListener.onEvent(eventSource, id, type, PlusJsonUtils.toJsonString(completionResponse));
            }

            @Override
            public void onClosed(@NotNull EventSource eventSource) {
                eventSourceListener.onClosed(eventSource);
            }

            @Override
            public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
                eventSourceListener.onOpen(eventSource, response);
            }

            @Override
            public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                eventSourceListener.onFailure(eventSource, t, response);
            }
        };
    }
}
